/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.bigtop.datagenerators.samplers.samplers;
import java.util.Collection;
import java.util.Map;
import java.util.Random;
import org.apache.bigtop.datagenerators.samplers.SeedFactory;
import org.apache.bigtop.datagenerators.samplers.pdfs.MultinomialPDF;
import org.apache.bigtop.datagenerators.samplers.pdfs.ProbabilityDensityFunction;
import org.apache.commons.lang3.tuple.Pair;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
/**
 * Samples elements of a finite domain with probability proportional to
 * per-element weights, using a cumulative-probability ("roulette wheel")
 * lookup driven by a seeded {@link Random}.
 *
 * <p>Weights are normalized by their sum at construction time. Weights are
 * assumed to be non-negative with a positive, finite total; if the total is
 * zero, NaN or infinite (or the domain is empty), construction succeeds but
 * every call to {@link #sample()} throws {@link IllegalStateException}.
 *
 * @param <T> type of the sampled elements
 */
public class RouletteWheelSampler<T> implements Sampler<T>
{
	/** Source of the uniform draws used by {@link #sample()}; seeded from a {@link SeedFactory}. */
	Random rng;

	/**
	 * Elements paired with their cumulative probability boundary, in the
	 * iteration order of the weight map they were built from.
	 */
	final ImmutableList<Pair<T, Double>> wheel;

	/**
	 * Creates a sampler from explicit (unnormalized) weights.
	 *
	 * @param domainWeights element to weight; iteration order determines wheel order
	 * @param factory supplies the seed for this sampler's random number generator
	 * @return a new sampler
	 */
	public static <T> RouletteWheelSampler<T> create(Map<T, Double> domainWeights, SeedFactory factory)
	{
		return new RouletteWheelSampler<T>(domainWeights, factory);
	}

	/**
	 * Creates a sampler over {@code pdf.getData()}, weighting each element by
	 * {@code pdf.probability(element)}.
	 *
	 * @param pdf supplies both the domain and the weights
	 * @param factory supplies the seed for this sampler's random number generator
	 * @return a new sampler
	 */
	public static <T> RouletteWheelSampler<T> create(MultinomialPDF<T> pdf, SeedFactory factory)
	{
		return new RouletteWheelSampler<T>(pdf.getData(), pdf, factory);
	}

	/**
	 * Creates a sampler over {@code data}, weighting each element by
	 * {@code pdf.probability(element)}.
	 *
	 * @param data domain elements; duplicates collapse to a single entry
	 * @param pdf weight function evaluated once per element
	 * @param factory supplies the seed for this sampler's random number generator
	 * @return a new sampler
	 */
	public static <T> RouletteWheelSampler<T> create(Collection<T> data, ProbabilityDensityFunction<T> pdf, SeedFactory factory)
	{
		return new RouletteWheelSampler<T>(data, pdf, factory);
	}

	/**
	 * Creates a sampler that picks uniformly among the distinct elements of
	 * {@code data} (each distinct element gets weight 1.0).
	 *
	 * @param data domain elements; duplicates collapse to a single entry
	 * @param factory supplies the seed for this sampler's random number generator
	 * @return a new sampler
	 */
	public static <T> RouletteWheelSampler<T> createUniform(Collection<T> data, SeedFactory factory)
	{
		// Insertion-ordered so the wheel layout (and thus seeded output) follows
		// the caller's collection order rather than element hash codes.
		Map<T, Double> pdf = Maps.newLinkedHashMap();
		for(T datum : data)
		{
			pdf.put(datum, 1.0);
		}
		return create(pdf, factory);
	}

	/**
	 * Builds a sampler from explicit (unnormalized) weights.
	 *
	 * @param domainWeights element to weight; iteration order determines wheel order
	 * @param factory supplies the seed for this sampler's random number generator
	 */
	public RouletteWheelSampler(Map<T, Double> domainWeights, SeedFactory factory)
	{
		this.rng = new Random(factory.getNextSeed());
		this.wheel = this.normalize(domainWeights);
	}

	/**
	 * Builds a sampler over {@code data}, weighting each element by
	 * {@code pdf.probability(element)}.
	 *
	 * @param data domain elements; duplicates collapse to a single entry
	 * @param pdf weight function evaluated once per element
	 * @param factory supplies the seed for this sampler's random number generator
	 */
	public RouletteWheelSampler(Collection<T> data, ProbabilityDensityFunction<T> pdf, SeedFactory factory)
	{
		this.rng = new Random(factory.getNextSeed());
		// Insertion-ordered so the wheel layout (and thus seeded output) follows
		// the order of data rather than element hash codes.
		Map<T, Double> domainWeights = Maps.newLinkedHashMap();
		for(T datum : data)
		{
			double prob = pdf.probability(datum);
			domainWeights.put(datum, prob);
		}
		this.wheel = this.normalize(domainWeights);
	}

	/**
	 * Converts weights into cumulative probability boundaries.
	 *
	 * @param domainWeights element to weight
	 * @return one (element, cumulative probability) pair per map entry, in
	 *         map iteration order
	 */
	private ImmutableList<Pair<T, Double>> normalize(Map<T, Double> domainWeights)
	{
		double weightSum = 0.0;
		for(double weight : domainWeights.values())
		{
			weightSum += weight;
		}

		// Only a positive, finite total yields a meaningful distribution. Otherwise
		// the boundaries are left as computed (NaN/0) so sample() keeps failing.
		boolean validTotal = weightSum > 0.0 && !Double.isInfinite(weightSum);

		int remaining = domainWeights.size();
		double cumProb = 0.0;
		ImmutableList.Builder<Pair<T, Double>> builder = ImmutableList.builder();
		for(Map.Entry<T, Double> entry : domainWeights.entrySet())
		{
			cumProb += entry.getValue() / weightSum;
			remaining--;
			// Floating-point rounding can leave the running sum slightly below 1.0.
			// Draws in that gap would match no bucket, so pin the last boundary.
			double boundary = (remaining == 0 && validTotal) ? 1.0 : cumProb;
			builder.add(Pair.of(entry.getKey(), boundary));
		}
		return builder.build();
	}

	/**
	 * Draws one element: takes a uniform value in [0, 1) from {@link #rng} and
	 * returns the first wheel entry whose cumulative boundary exceeds it.
	 *
	 * @return the sampled element
	 * @throws IllegalStateException if no boundary exceeds the draw, which
	 *         happens when the wheel is empty or the weights did not sum to a
	 *         positive, finite total
	 */
	public T sample()
	{
		double r = rng.nextDouble();
		for(Pair<T, Double> cumProbPair : wheel)
		{
			if(r < cumProbPair.getValue())
			{
				return cumProbPair.getKey();
			}
		}
		throw new IllegalStateException("Invalid state -- RouletteWheelSampler should never fail to sample!");
	}
}